%  Code written by someone else, obtained from: https://stevegilham.blogspot.com/2008/10/first-refactoring-of-star-in-erlang.html

-module(a_star).
-export([main/0]).
%% 其实这个算法也是有一些问题的，比如：在X=7时，0已经出现在了open列表中的，本可以结束了，但是还是多执行了一次

%% A* search from Start to Goal.
%% Two sets are maintained: the closed set holds nodes that have already been
%% expanded, the open set holds nodes discovered but not yet expanded.
%% Every dict value is a one-element list because the tables are built with
%% dict:append/3.
%% Returns whatever astar_step/6 returns: the atom failure, or the path as a
%% list of nodes running from Goal back to Start.
astar(Start, Goal) ->
    Blank = dict:new(),
    %% Cost from Start along the best known path; Start itself costs 0.
    InitialG = dict:append(Start, 0, Blank),
    %% f = g + h; g is 0 for Start, so f is just the heuristic value.
    InitialF = dict:append(Start, h_score(Start), Blank),
    %% Predecessor table used to rebuild the path; Start has no predecessor.
    Parents = dict:append(Start, none, Blank),
    Frontier = sets:add_element(Start, sets:new()),
    Visited = sets:new(),
    astar_step(Goal, Visited, Frontier, InitialF, InitialG, Parents).

%% One iteration of the A* main loop.
%%
%% Goal      - node being searched for
%% Closedset - nodes that have already been expanded
%% Openset   - nodes discovered but not yet expanded
%% Fscore, Gscore, CameFrom - dicts keyed by node; each value is a
%%                            one-element list
%%
%% Returns the atom failure when the open set is empty, otherwise the path
%% produced by reconstruct_path/2.
%%
%% The search finishes only when Goal is the lowest-f node taken from the open
%% set. Stopping as soon as Goal merely appears in the open set would accept
%% the first route discovered, which is not necessarily the cheapest one.
astar_step(Goal, Closedset, Openset, Fscore, Gscore, CameFrom) ->
    case sets:size(Openset) of
        0 ->
            failure;
        _ ->
            BestNode = best_step(sets:to_list(Openset), Fscore, none, infinity),
            io:format("best step . X = ~p \n", [BestNode]),
            case BestNode of
                Goal ->
                    NodeList = reconstruct_path(CameFrom, Goal),
                    io:format("CameFrom NodeList . NodeList = ~p \n", [NodeList]),
                    NodeList;
                _ ->
                    %% Move the chosen node from the open set to the closed set,
                    %% then relax every edge leaving it.
                    NextOpen = sets:del_element(BestNode, Openset),
                    NextClosed = sets:add_element(BestNode, Closedset),
                    Neighbours = neighbour_nodes(BestNode),
                    {NewOpen, NewF, NewG, NewFrom} =
                        scan(BestNode, Neighbours, NextOpen, NextClosed,
                             Fscore, Gscore, CameFrom),
                    io:format("now openset = ~p \n", [sets:to_list(NewOpen)]),
                    astar_step(Goal, NextClosed, NewOpen, NewF, NewG, NewFrom)
            end
    end.

%% Relaxes the edges from X to each neighbour in the list, in order.
%% Neighbours already in Closed are skipped. Returns {Open, F, G, From} with
%% the open set and the three score tables updated.
scan(_X, [], Open, _Closed, F, G, From) ->
    {Open, F, G, From};
scan(X, [Y | Rest], Open, Closed, F, G, From) ->
    {Open1, F1, G1, From1} =
        case sets:is_element(Y, Closed) of
            true ->
                {Open, F, G, From};
            false ->
                [CostToX] = dict:fetch(X, G),
                Tentative = CostToX + dist_between(X, Y),
                scan_edge(sets:is_element(Y, Open), X, Y, Tentative, Open, F, G, From)
        end,
    scan(X, Rest, Open1, Closed, F1, G1, From1).

%% Handles one neighbour Y of X that is not in the closed set.
%% The first argument tells whether Y is already in the open set.
scan_edge(true, X, Y, Tentative, Open, F, G, From) ->
    %% Y is already queued: only record the new route when it is strictly cheaper.
    [KnownCost] = dict:fetch(Y, G),
    if
        Tentative < KnownCost ->
            {F1, G1, From1} = update(X, Y, F, G, From, Tentative),
            {Open, F1, G1, From1};
        true ->
            {Open, F, G, From}
    end;
scan_edge(false, X, Y, Tentative, Open, F, G, From) ->
    %% First time Y is seen: queue it and record its scores.
    {F1, G1, From1} = update(X, Y, F, G, From, Tentative),
    {sets:add_element(Y, Open), F1, G1, From1}.

%% Records that the best known route to Y now arrives from X with cost GValue.
%%
%% X       - predecessor of Y on the new route
%% Y       - node whose entries are replaced
%% OldF, OldG, OldFrom - current f-score, g-score and predecessor dicts
%% GValue  - cost from the start node to Y along the new route
%%
%% Returns {NewF, NewG, NewFrom}. Any previous entry for Y is overwritten.
%% Values are stored as one-element lists because the rest of the module reads
%% them with patterns such as [G0] = dict:fetch(Node, G).
update(X, Y, OldF, OldG, OldFrom, GValue) ->
    NewFrom = dict:store(Y, [X], OldFrom),
    NewG = dict:store(Y, [GValue], OldG),
    %% Estimated total distance from start to goal through Y.
    NewF = dict:store(Y, [GValue + h_score(Y)], OldF),
    {NewF, NewG, NewFrom}.

%% Follows the predecessor links in CameFrom starting at Node.
%% Returns the visited nodes as a list beginning with Node and ending with the
%% node whose predecessor entry is [none].
reconstruct_path(CameFrom, Node) ->
    path_via(dict:fetch(Node, CameFrom), CameFrom, Node).

%% Dispatches on the predecessor entry fetched for Node.
path_via([none], _CameFrom, Node) ->
    [Node];
path_via([Parent], CameFrom, Node) ->
    [Node | reconstruct_path(CameFrom, Parent)].

%% Picks the node with the lowest score from the list Open.
%% Score is a dict mapping each node to a one-element list holding its score.
%% Best and BestValue seed the search; while Best is the atom none the next
%% node is taken unconditionally. A later node replaces the current choice only
%% when its score is strictly smaller, so the earliest minimum wins.
%% Returns Best unchanged when Open is empty.
best_step(Open, Score, Best, BestValue) ->
    Choose =
        fun(Node, {Leader, LeaderValue}) ->
            [Value] = dict:fetch(Node, Score),
            case Leader =:= none orelse Value < LeaderValue of
                true -> {Node, Value};
                false -> {Leader, LeaderValue}
            end
        end,
    {Winner, _WinnerValue} = lists:foldl(Choose, {Best, BestValue}, Open),
    Winner.

%% specialize for the torch-carrier problem
%% bits 0-4 represent torch, 1m, 2m, 5m, 10m
%% set bit = on the near side

%% Every possible crossing: the torch (bit 0) together with one or two walkers.
%% Single-walker crossings come first, followed by the two-walker crossings.
swaps() ->
    Torch = 1,
    Solo = [Torch bor Walker || Walker <- [2, 4, 8, 16]],
    Pairs = [Torch bor A bor B
             || {A, B} <- [{2, 4}, {2, 8}, {4, 8}, {2, 16}, {4, 16}, {8, 16}]],
    Solo ++ Pairs.

%% Time taken by the group encoded in Swap, i.e. the time of its slowest member.
%% The highest set bit among 16, 8 and 4 selects 10, 5 or 2; anything else is 1.
crossing_time(Swap) when Swap band 16 > 0 -> 10;
crossing_time(Swap) when Swap band 8 > 0 -> 5;
crossing_time(Swap) when Swap band 4 > 0 -> 2;
crossing_time(_Swap) -> 1.

%% Presentation form of a crossing.
%% Finds the highest set bit among 16, 8, 4 and 2 and hands it to compound/2,
%% which renders that member and then the rest. Returns "" when none of those
%% bits is set.
display(Swap) when Swap band 16 > 0 -> compound(Swap, 16);
display(Swap) when Swap band 8 > 0 -> compound(Swap, 8);
display(Swap) when Swap band 4 > 0 -> compound(Swap, 4);
display(Swap) when Swap band 2 > 0 -> compound(Swap, 2);
display(_Swap) -> "".

%% Renders the crossing time of Swap as text, followed by the "+"-prefixed
%% rendering of Swap with Bit cleared.
compound(Swap, Bit) ->
    Head = erlang:integer_to_list(crossing_time(Swap)),
    Tail = decorate(display(Swap bxor Bit)),
    Head ++ Tail.

%% Prefixes a non-empty string with "+"; an empty string is returned as is.
decorate([]) ->
    [];
decorate([_ | _] = Value) ->
    [$+ | Value].

%% Finds the neighbouring states of X: every state reachable by one crossing
%% from swaps/0 that is possible on the side where the torch currently is.
neighbour_nodes(X) ->
    TorchSide = equivalent_point(X),
    compatible(X, TorchSide, [], swaps()).

%% Returns X itself when bit 0 is set, otherwise X with all five bits flipped
%% (31 bxor X).
equivalent_point(X) when X band 1 =:= 1 ->
    X;
equivalent_point(X) when X band 1 =:= 0 ->
    31 bxor X.

%% Collects the states reachable from X by each swap whose bits are all set in Y.
%% Each matching swap contributes X bxor Swap, pushed onto the front of Outlist,
%% so the results appear in reverse order of the swap list.
compatible(X, Y, Outlist, Swaps) ->
    Collect =
        fun(Swap, Acc) ->
            case Y band Swap of
                Swap -> [X bxor Swap | Acc];
                _ -> Acc
            end
        end,
    lists:foldl(Collect, Outlist, Swaps).

%% Cost of moving between states X and Y: the crossing time of the bits in
%% which the two states differ.
dist_between(X, Y) ->
    crossing_time(X bxor Y).

%% Heuristic (h value): estimated remaining cost from Node.
%% When none of the walker bits (2, 4, 8, 16) is set nobody is left to cross,
%% so the estimate is 0; an admissible heuristic must not charge anything for
%% a finished state. Otherwise the estimate is the crossing time of the
%% slowest walker still encoded in Node.
h_score(Node) when Node band 30 =:= 0 ->
    0;
h_score(Node) ->
    crossing_time(Node).

%% Entry point: searches from state 31 to state 0 and prints either an error
%% line or every crossing of the route followed by the total time.
main() ->
    report_route(astar(31, 0)).

%% Prints the outcome of the search.
report_route(failure) ->
    io:format("[Error]not find route...\n");
report_route(RouteList) ->
    %% The route arrives goal-first; reverse it to walk it from the start.
    [First | Trace] = lists:reverse(RouteList),
    Time = print_result(Trace, First, 0),
    io:format("Time taken = ~B minutes~n", [Time]).

%% Prints each step of the route and returns the accumulated crossing time.
%% The first argument holds the remaining states, Prev is the state before
%% them and Time is the time accumulated so far.
print_result(Trace, Prev, Time) ->
    Step =
        fun(State, {Before, Elapsed}) ->
            Swap = State bxor Before,
            print_swap(Swap, State band 1),
            {State, Elapsed + crossing_time(Swap)}
        end,
    {_Last, Total} = lists:foldl(Step, {Prev, Time}, Trace),
    Total.

%% Prints one crossing.
%% Side is bit 0 of the state reached by the crossing: 0 prints the group
%% followed by a right-pointing arrow, 1 prints a left-pointing arrow followed
%% by the group.
%% The right-arrow line used to end in a stray ";" (it looks like the residue
%% of an HTML "&gt;" entity); it is removed so both directions print cleanly.
print_swap(Swap, 0) ->
    io:format("    ~s -->~n", [display(Swap)]);
print_swap(Swap, 1) ->
    io:format("<-- ~s~n", [display(Swap)]).